[Manual] Rcpp使用手册2

这是一篇关于RcppArmadillo的基础性技术文档

Posted by Leung ZhengHua on 2017-11-16

本文总点击量

如果你是第一次阅读这篇文章,以下资源贴可能是你感兴趣的:

RcppArmadillo简明手册
Rcpp相关知识整理
RcppArmadillo官方文档
Armadillo C++ linear algebra library 学习笔记(4)——矩阵的运算
Armadillo矩阵库的使用(二)之API接口


.cpp文件的储存位置

当我们在Rstudio中新建C++文件时,它会直接新建一个文档,方便我们在里面直接编写C++函数。由于Dirk的书《Seamless R and C++ Integration with Rcpp》是2013年出版的,当时Rcpp Attributes这一特性还没有被CRAN批准,所以当时调用和编写Rcpp函数还比较繁琐。我们需要在.R文件中以字符文本的格式写入C++代码,然后调用cxxfunction,加上plugin插件(类似引擎、驱动的东西)对c++代码进行编译。其实在Rstudio的.cpp文件窗口编写c++函数是最方便的,我不会使用C++的IDE,也不会调试C++的代码,我只能在Rstudio里面力求自己写的cpp函数尽量正确,一招print走遍天下。这对我来说实在是一个痛苦的事情。

RcppArmadillo中变量类型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp;
using namespace arma;
// [[Rcpp::export]]
colvec mvndrawC(colvec mu, mat sig) {
double k = mu.size();
colvec aux = as<colvec>(rnorm(k));
mat csig = chol(sig).t();
colvec out = mu + csig*aux;
return(out);
}
/*** R
m=c(2,3,4)
s=diag(3)
mvndrawC(m,s)
*/

案例研究

下面的函数目的是求解某种情况下lasso问题的$\beta$系数,这个cpp文件导入RStudio后可以直接source

  • Line 1表示在Rcpp中加载依赖包
  • Line 2-3表示导入两个库的头文件,#include <RcppArmadillo.h>是因为需要使用Armadillo线性代数库才新加入的
  • Line 4-5 导入命名空间及其内部函数,下文的函数由此可以直接使用
  • Line 7 表明Line 7之后首个函数将会被导出到R的工作空间,Line 7之前的函数怎么定义都不会被导出
  • Line 54 fabs表示浮点数类型的取绝对值;abs会强制转换为int类型,导致判断条件错误地为True!
  • Line 70 使用List::create()函数创造列表,并且重命名列表中返回的元素
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <Rcpp.h>
using namespace Rcpp;
using namespace arma;
// [[Rcpp::export]]
List timesTwo(mat x,mat y,double lamb,double bound) {
int p=x.n_cols;
mat A(p,p);
//A=t(x)%*%x+diag(p);
A=x.t()*x+eye<mat>(p,p); //mat A(p,n)=0; wrong!
// v <- matrix(rep(1,p),p)
mat v = ones<mat>(p,1);
// lam<- matrix(rep(lamb,p),p)
mat lam=lamb * ones<mat>(p,1);
// f0 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
// +(1/2)*t(y)%*%y + t(lam)%*%v)
mat f0=((1/2)*v.t()*A*v-v.t()*x.t()*y+1/2*y.t()*y+lam.t()*v);
// aa = (abs(A)+A)/2
mat aa=(abs(A)+A)/2;
// cc = (abs(A)-A)/2
mat cc=(abs(A)-A)/2;
// a=aa %*% v
mat a=aa*v;
// c=cc %*% v
mat c=cc*v;
// b = lam - t(x)%*%y
mat b=lam-x.t()*y;
// v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
// v.print("before");
v=(-b+sqrt(b%b+4*a%c))/(2*a)%v;
//v.print("after");
// f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
// +(1/2)*t(y)%*%y + t(lam)%*%v)
mat f1 =((1/2)*v.t()*A*v-v.t()*x.t()*y+1/2*y.t()*y+lam.t()*v);
// k = 1
int k=1;
// while(abs(f0-f1)>bound){
// f0 <- f1
//
// a=aa %*% v
// c=cc %*% v
// v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
// f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
// +(1/2)*t(y)%*%y + t(lam)%*%v)
//
// k=k+1
// f0-f1
// }
//
while(fabs(f0(0,0)-f1(0,0))>bound) {
f0(0,0)=f1(0,0);
a=aa*v;
c=cc*v;
//(a/10000).print("a"); // must double "
//(b/10000).print("b");
//(c/10000).print("c");
v=(-b+sqrt(b%b+4*a%c))/(2*a)%v;
f1=((1/2)*v.t()*A*v-v.t()*x.t()*y+1/2*y.t()*y+lam.t()*v);
k=k+1;
//printf("%g\n",fabs(f0(0,0)-f1(0,0))); //abs() return int type
}
// v[v <=1e-2] = 0
for(int i=0;i<p;i++) {if(v[i]<0.001){v[i]=0;}}
//return(list(beta = v,k = k))
return List::create(_["beta"]=v,_["k"]=k); //return List::create(v,k);
}
// You can include R code blocks in C++ files processed with sourceCpp
// (useful for testing and development). The R code will be automatically
// run after the compilation.
//=============================================================
/*** R
Update<-function(x,y,lamb,bound){
n <- dim(x)[1];p <- dim(x)[2]
A = t(x)%*%x+diag(p)
v <- matrix(rep(1,p),p)
lam=NULL
lam<- matrix(rep(lamb,p),p)
f0 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
+(1/2)*t(y)%*%y + t(lam)%*%v)
aa = (abs(A)+A)/2
cc = (abs(A)-A)/2
a=aa %*% v
c=cc %*% v
b = lam - t(x)%*%y
v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
+(1/2)*t(y)%*%y + t(lam)%*%v)
k = 1
while(abs(f0-f1)>bound){
f0 <- f1
a=aa %*% v
c=cc %*% v
v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
+(1/2)*t(y)%*%y + t(lam)%*%v)
k=k+1
f0-f1
}
v[v <=1e-2] = 0
return(list(beta = v,k = k))
}
n=1080 #样本容量
x1=1:n
x2=runif(n,-23,2)
x3=runif(n,-20,5)
x4=runif(n,-10,1)
x=as.matrix(data.frame(x1,x2,x3,x4))
#t(x)%*%x
e=rnorm(n,0,5)
y=0.5+5*x1+12*x2+0.9*x4+e
y=as.matrix(y)
model=lm(y~x1+x2+x3+x4)
summary(model)
system.time({model1=timesTwo(x,y,0.5,1e-5)})
system.time({model2=Update(x,y,0.5,1e-5)})
*/

上述文件在source之后,在我尊贵的阿苏斯笔记本下的结果对比为:

system.time({model1=timesTwo(x,y,0.5,1e-5)})
用户 系统 流逝
1.18 0.00 1.18
system.time({model2=Update(x,y,0.5,1e-5)})
用户 系统 流逝
20.73 0.00 21.06

可以发现,利用RcppArmadillo线性代数库优化的速度比R版本快了10倍!